{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n",
    "valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking', filter=True)\n",
    "predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking', filter=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8439.76it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4712.19it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 776436.02it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 130731.35it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 70052.12it/s] \n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 594972.60it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 687411.98it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 2737950.04it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8895.89it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4797.60it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 117194.04it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 153455.45it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 118119.83it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 595930.49it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 710150.90it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 118891.00it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 82137.05it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=100, remove_stop_words=False)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 8023.85it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 4869.51it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 118055.46it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 102690.16it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 117960.17it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 161983.25it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 514256.54it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 70618.97it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 62619.15it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 9101.69it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4820.53it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 127858.17it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 141966.59it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 130761.09it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 254102.77it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 481694.67it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 92418.19it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 15213.65it/s]\n"
     ]
    }
   ],
   "source": [
    "valid_pack_processed = preprocessor.transform(valid_pack)\n",
    "predict_pack_processed = preprocessor.transform(predict_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankCrossEntropyLoss(num_neg=4))\n",
    "ranking_task.metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n",
    "    mz.metrics.MeanAveragePrecision()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to DRMM.\n",
      "Parameter \"embedding_trainable\" set to True.\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           (None, 10, 300)      5002200     text_left[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 10, 1)        300         embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "match_histogram (InputLayer)    (None, 10, 30)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "attention_mask (Lambda)         (None, 10, 1)        0           dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 10, 10)       310         match_histogram[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "attention_probs (Lambda)        (None, 10, 1)        0           attention_mask[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 10, 1)        11          dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 1, 1)         0           attention_probs[0][0]            \n",
      "                                                                 dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 1)            0           dot_1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 1)            2           flatten_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 5,002,823\n",
      "Trainable params: 5,002,823\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "bin_size = 30\n",
    "model = mz.models.DRMM()\n",
    "model.params['input_shapes'] = [[10,], [10, bin_size,]]\n",
    "model.params['task'] = ranking_task\n",
    "model.params['mask_value'] = 0\n",
    "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n",
    "model.params['embedding_output_dim'] = 300\n",
    "model.params['mlp_num_layers'] = 1\n",
    "model.params['mlp_num_units'] = 10\n",
    "model.params['mlp_num_fan_out'] = 1\n",
    "model.params['mlp_activation_func'] = 'tanh'\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=300)\n",
    "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# normalize the word embedding for fast histogram generating.\n",
    "l2_norm = np.sqrt((embedding_matrix*embedding_matrix).sum(axis=1))\n",
    "embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_generator = mz.HistogramDataGenerator(data_pack=predict_pack_processed,\n",
    "                                           embedding_matrix=embedding_matrix,\n",
    "                                           bin_size=bin_size, \n",
    "                                           hist_mode='LCH')\n",
    "pred_x, pred_y = pred_generator[:]\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, \n",
    "                                           x=pred_x, \n",
    "                                           y=pred_y, \n",
    "                                           once_every=1, \n",
    "                                           batch_size=len(pred_y),\n",
    "                                           model_save_path='./drmm_pretrained_model/'\n",
    "                                          )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_generator = mz.HistogramPairDataGenerator(train_pack_processed, embedding_matrix, bin_size, 'LCH',\n",
    "                                                num_dup=2, num_neg=4, batch_size=20)\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 10s 99ms/step - loss: 1.5309\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4986762775699343 - normalized_discounted_cumulative_gain@5(0): 0.5504466953195465 - mean_average_precision(0): 0.5154984774159246\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 8s 76ms/step - loss: 1.3919\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5549393331491589 - normalized_discounted_cumulative_gain@5(0): 0.6141207417094355 - mean_average_precision(0): 0.5726396260841911\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 8s 81ms/step - loss: 1.2867\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.573692471991038 - normalized_discounted_cumulative_gain@5(0): 0.6329777160881896 - mean_average_precision(0): 0.5860880510000691\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 9s 87ms/step - loss: 1.1994\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6033421839585815 - normalized_discounted_cumulative_gain@5(0): 0.6511131484452872 - mean_average_precision(0): 0.6095223789672819\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 9s 84ms/step - loss: 1.1051\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5991475474610819 - normalized_discounted_cumulative_gain@5(0): 0.6520339386707785 - mean_average_precision(0): 0.6064930533061078\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 8s 79ms/step - loss: 1.0113\n",
      "Epoch 6/30\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5919183117111824 - normalized_discounted_cumulative_gain@5(0): 0.6465694657465371 - mean_average_precision(0): 0.6026914849341061\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 9s 88ms/step - loss: 0.9285\n",
      "\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6019177810638232 - normalized_discounted_cumulative_gain@5(0): 0.6586556349475311 - mean_average_precision(0): 0.6148511548469406\n",
      "Epoch 8/30\n",
      "101/102 [============================>.] - ETA: 0s - loss: 0.8592\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 9s 86ms/step - loss: 0.8575\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5925760307863998 - normalized_discounted_cumulative_gain@5(0): 0.6495575879366299 - mean_average_precision(0): 0.6019260728414415\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 9s 85ms/step - loss: 0.8027\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6002159618486113 - normalized_discounted_cumulative_gain@5(0): 0.6576207137730721 - mean_average_precision(0): 0.613792722366319\n",
      "Epoch 10/30\n",
      "101/102 [============================>.] - ETA: 0s - loss: 0.7643Epoch 10/30\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6002159618486113 - normalized_discounted_cumulative_gain@5(0): 0.6576207137730721 - mean_average_precision(0): 0.613792722366319\n",
      "102/102 [==============================] - 8s 79ms/step - loss: 0.7645\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6032611122698778 - normalized_discounted_cumulative_gain@5(0): 0.6559739238278149 - mean_average_precision(0): 0.6110217698280545\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 9s 88ms/step - loss: 0.7317\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5925477525590856 - normalized_discounted_cumulative_gain@5(0): 0.6549344362715478 - mean_average_precision(0): 0.6070687281818817\n",
      "Epoch 12/30\n",
      "101/102 [============================>.] - ETA: 0s - loss: 0.7072\n",
      "102/102 [==============================] - 9s 86ms/step - loss: 0.7116\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5914723561389517 - normalized_discounted_cumulative_gain@5(0): 0.6541533130679608 - mean_average_precision(0): 0.6081659022220937\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 8s 83ms/step - loss: 0.6947\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5846953239145032 - normalized_discounted_cumulative_gain@5(0): 0.6484054609449952 - mean_average_precision(0): 0.600627964304409\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 9s 87ms/step - loss: 0.6813\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5918600152369855 - normalized_discounted_cumulative_gain@5(0): 0.6504547573842671 - mean_average_precision(0): 0.6085209371894704\n",
      "Epoch 14/30\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.6699\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5891092775638822 - normalized_discounted_cumulative_gain@5(0): 0.6455536505470437 - mean_average_precision(0): 0.6035318803016794\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 8s 81ms/step - loss: 0.6599\n",
      "Epoch 16/30\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6009115392223514 - normalized_discounted_cumulative_gain@5(0): 0.6518797057497255 - mean_average_precision(0): 0.6075287039187561\n",
      "Epoch 17/30\n",
      "101/102 [============================>.] - ETA: 0s - loss: 0.6557Epoch 17/30\n",
      "\n",
      "102/102 [==============================] - 9s 87ms/step - loss: 0.6534\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5984631037335253 - normalized_discounted_cumulative_gain@5(0): 0.6539943205177465 - mean_average_precision(0): 0.6094837329727536\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 9s 83ms/step - loss: 0.6449\n",
      "\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5994042643738182 - normalized_discounted_cumulative_gain@5(0): 0.6555438249324966 - mean_average_precision(0): 0.6124138482609293\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 9s 84ms/step - loss: 0.6413\n",
      "\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.605947093667994 - normalized_discounted_cumulative_gain@5(0): 0.6536269149996055 - mean_average_precision(0): 0.6158320325731008\n",
      "Epoch 20/30\n",
      "101/102 [============================>.] - ETA: 0s - loss: 0.6342Epoch 20/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.6360\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6043205182024866 - normalized_discounted_cumulative_gain@5(0): 0.6576467139572901 - mean_average_precision(0): 0.6176885855372487\n",
      "Epoch 21/30\n",
      "101/102 [============================>.] - ETA: 0s - loss: 0.6300Epoch 21/30\n",
      "102/102 [==============================] - 9s 84ms/step - loss: 0.6319\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.608241017664401 - normalized_discounted_cumulative_gain@5(0): 0.659326742713726 - mean_average_precision(0): 0.6175278819559088\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 9s 84ms/step - loss: 0.6259\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6105996485513213 - normalized_discounted_cumulative_gain@5(0): 0.655769142117799 - mean_average_precision(0): 0.6183184020897835\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 8s 82ms/step - loss: 0.6231\n",
      "Epoch 23/30\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6085797485118771 - normalized_discounted_cumulative_gain@5(0): 0.65737012582498 - mean_average_precision(0): 0.6185547925160473\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 9s 85ms/step - loss: 0.6188\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6067543434546379 - normalized_discounted_cumulative_gain@5(0): 0.6533856452279714 - mean_average_precision(0): 0.614064989466583\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 8s 81ms/step - loss: 0.6162\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6128889623964834 - normalized_discounted_cumulative_gain@5(0): 0.6595256434934039 - mean_average_precision(0): 0.6197743397705909\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 7s 65ms/step - loss: 0.6121\n",
      "Epoch 26/30\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6127241258924424 - normalized_discounted_cumulative_gain@5(0): 0.6610313052668069 - mean_average_precision(0): 0.6225225741690487\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 7s 65ms/step - loss: 0.6075\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6127241258924424 - normalized_discounted_cumulative_gain@5(0): 0.6610313052668069 - mean_average_precision(0): 0.6225225741690487\n",
      "\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6035089831907896 - normalized_discounted_cumulative_gain@5(0): 0.6604226041817534 - mean_average_precision(0): 0.6144724008894449\n",
      "Epoch 28/30\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "101/102 [============================>.] - ETA: 0s - loss: 0.6069Epoch 28/30\n",
      "102/102 [==============================] - 7s 68ms/step - loss: 0.6053\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6030962757658368 - normalized_discounted_cumulative_gain@5(0): 0.6550618042545611 - mean_average_precision(0): 0.6153120002258798\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 7s 67ms/step - loss: 0.5999ation: normalized_discounted_cumulative_gain@3(0): 0.6030962757658368 - normalized_discounted_cumulative_gain@5(0): 0.6550618042545611 - mean_average_precision(0): 0.615312000225879\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6095742357064031 - normalized_discounted_cumulative_gain@5(0): 0.6609072243676714 - mean_average_precision(0): 0.6193984247841957\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 7s 67ms/step - loss: 0.5962\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6097187964978661 - normalized_discounted_cumulative_gain@5(0): 0.662263044884704 - mean_average_precision(0): 0.6186700966708153\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.49939275],\n",
       "       [ 2.2709463 ],\n",
       "       [-3.1269906 ],\n",
       "       [-0.03658384],\n",
       "       [-2.9051058 ],\n",
       "       [-1.0289317 ],\n",
       "       [ 1.0633922 ],\n",
       "       [-0.64751583],\n",
       "       [ 1.0268258 ],\n",
       "       [ 1.7241758 ]], dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drmm_model = mz.load_model('./drmm_pretrained_model/16')\n",
    "test_generator = mz.HistogramDataGenerator(data_pack=valid_pack_processed[:10],\n",
    "                                           embedding_matrix=embedding_matrix,\n",
    "                                           bin_size=bin_size, \n",
    "                                           hist_mode='LCH')\n",
    "test_x, test_y = test_generator[:]\n",
    "prediction = drmm_model.predict(test_x)\n",
    "prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "shutil.rmtree('./drmm_pretrained_model/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
